{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# tf.data API使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "# 导入\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl,np,pd,sklearn,tf,keras:\n",
    "    print(module.__name__,module.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 利用内存中的数据构建Dataset\n",
    "\n",
    "#### numpy生成数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<TensorSliceDataset shapes: (), types: tf.int32>\n"
     ]
    }
   ],
   "source": [
    "# 从内存中构建数据集\n",
    "dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "tf.Tensor(3, shape=(), dtype=int32)\n",
      "tf.Tensor(4, shape=(), dtype=int32)\n",
      "tf.Tensor(5, shape=(), dtype=int32)\n",
      "tf.Tensor(6, shape=(), dtype=int32)\n",
      "tf.Tensor(7, shape=(), dtype=int32)\n",
      "tf.Tensor(8, shape=(), dtype=int32)\n",
      "tf.Tensor(9, shape=(), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# 对dataset遍历\n",
    "for item in dataset:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1.repeat操作，每一个epoch就是遍历一遍数据\n",
    "\n",
    "2.get batch获取batch size个数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)\n",
      "tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)\n",
      "tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)\n",
      "tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)\n",
      "tf.Tensor([8 9], shape=(2,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "dataset = dataset.repeat(3).batch(7)\n",
    "for item in dataset:\n",
    "    print(item)"
   ]
  },
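  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A variant sketch (an added illustration, not part of the original run): passing `drop_remainder=True` to `batch` discards the final short batch, so every batch has exactly batch_size elements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: with drop_remainder=True the trailing short batch ([8 9] above) is dropped.\n",
    "dataset_full = tf.data.Dataset.from_tensor_slices(np.arange(10))\n",
    "dataset_full = dataset_full.repeat(3).batch(7, drop_remainder=True)\n",
    "for item in dataset_full:\n",
    "    print(item)"
   ]
  },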
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.interleave：对现有的dataset中的每个数据做处理，处理后产生新的结果，interleave将结果合并起来形成新的数据集\n",
    "\n",
    "常见的case：文件dataset -> 具体数据集\n",
    "\n",
    "现有的dataset存储的是一系列文件的文件名，使用interleave做变化，遍历文件名数据集中的所有元素--文件名，把文件名对应的文件内容读取出来，这样的文件名就形成了新的数据集，interleave再把新的数据集合并起来，成为总的大数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "tf.Tensor(3, shape=(), dtype=int32)\n",
      "tf.Tensor(4, shape=(), dtype=int32)\n",
      "tf.Tensor(7, shape=(), dtype=int32)\n",
      "tf.Tensor(8, shape=(), dtype=int32)\n",
      "tf.Tensor(9, shape=(), dtype=int32)\n",
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "tf.Tensor(4, shape=(), dtype=int32)\n",
      "tf.Tensor(5, shape=(), dtype=int32)\n",
      "tf.Tensor(6, shape=(), dtype=int32)\n",
      "tf.Tensor(7, shape=(), dtype=int32)\n",
      "tf.Tensor(8, shape=(), dtype=int32)\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "tf.Tensor(3, shape=(), dtype=int32)\n",
      "tf.Tensor(4, shape=(), dtype=int32)\n",
      "tf.Tensor(5, shape=(), dtype=int32)\n",
      "tf.Tensor(8, shape=(), dtype=int32)\n",
      "tf.Tensor(9, shape=(), dtype=int32)\n",
      "tf.Tensor(5, shape=(), dtype=int32)\n",
      "tf.Tensor(6, shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "tf.Tensor(3, shape=(), dtype=int32)\n",
      "tf.Tensor(9, shape=(), dtype=int32)\n",
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "tf.Tensor(6, shape=(), dtype=int32)\n",
      "tf.Tensor(7, shape=(), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "dataset2 = dataset.interleave(\n",
    "    # map_fn 做什么样的变化\n",
    "    # cycle_length 并行的处理多少个元素\n",
    "    # block_length 从上面的变化中每次取多少个元素出来\n",
    "    lambda v: tf.data.Dataset.from_tensor_slices(v),\n",
    "    cycle_length=5,\n",
    "    block_length=5, # 达到均匀混合的效果\n",
    ")\n",
    "for item in dataset2:\n",
    "    print(item)"
   ]
  },
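  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal comparison sketch (added illustration): with `block_length=1`, interleave takes a single element from each of the `cycle_length` sub-datasets per turn, alternating element by element instead of in blocks of 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: block_length=1 alternates single elements across the 5 cycled sub-datasets.\n",
    "dataset2b = dataset.interleave(\n",
    "    lambda v: tf.data.Dataset.from_tensor_slices(v),\n",
    "    cycle_length=5,\n",
    "    block_length=1,\n",
    ")\n",
    "for item in dataset2b.take(10):\n",
    "    print(item.numpy())"
   ]
  },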
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### python原生的数据类型生成数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<TensorSliceDataset shapes: ((2,), ()), types: (tf.int32, tf.string)>\n",
      "[1 2] b'cat'\n",
      "[3 4] b'dog'\n",
      "[5 6] b'fox'\n"
     ]
    }
   ],
   "source": [
    "x = np.array([[1, 2], [3, 4], [5 ,6]])\n",
    "y = np.array(['cat', 'dog', 'fox'])\n",
    "dataset3 = tf.data.Dataset.from_tensor_slices((x,y))\n",
    "print(dataset3)\n",
    "\n",
    "for item_x,item_y in dataset3:\n",
    "    print(item_x.numpy(), item_y.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'feature': <tf.Tensor: id=98, shape=(2,), dtype=int32, numpy=array([1, 2])>, 'label': <tf.Tensor: id=99, shape=(), dtype=string, numpy=b'cat'>}\n",
      "[1 2] b'cat'\n",
      "{'feature': <tf.Tensor: id=100, shape=(2,), dtype=int32, numpy=array([3, 4])>, 'label': <tf.Tensor: id=101, shape=(), dtype=string, numpy=b'dog'>}\n",
      "[3 4] b'dog'\n",
      "{'feature': <tf.Tensor: id=102, shape=(2,), dtype=int32, numpy=array([5, 6])>, 'label': <tf.Tensor: id=103, shape=(), dtype=string, numpy=b'fox'>}\n",
      "[5 6] b'fox'\n"
     ]
    }
   ],
   "source": [
    "dataset4 = tf.data.Dataset.from_tensor_slices({\"feature\": x, \"label\": y})\n",
    "for item in dataset4:\n",
    "    print(item)\n",
    "    print(item['feature'].numpy(), item['label'].numpy())"
   ]
  },
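  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small added sketch: `from_tensor_slices` also accepts plain Python lists directly, without going through numpy first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: plain Python lists are converted to tensors automatically.\n",
    "dataset5 = tf.data.Dataset.from_tensor_slices(\n",
    "    {'feature': [[1, 2], [3, 4]], 'label': ['cat', 'dog']})\n",
    "for item in dataset5:\n",
    "    print(item['feature'].numpy(), item['label'].numpy())"
   ]
  },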
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 从CSV文件中读取数据生成数据集\n",
    "在文件中读取文件内容，构成Dataset，进行使用，以及在keras中进行使用\n",
    "#### 生成csv文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入数据集 房价预测\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "housing = fetch_california_housing()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(11610, 8) (11610,)\n",
      "(3870, 8) (3870,)\n",
      "(5160, 8) (5160,)\n"
     ]
    }
   ],
   "source": [
    "# 切分数据集\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "x_train_all,x_test,y_train_all,y_test = train_test_split(\n",
    "    housing.data,housing.target,random_state=7)\n",
    "x_train,x_valid,y_train,y_valid = train_test_split(\n",
    "    x_train_all,y_train_all,random_state=11)\n",
    "\n",
    "print(x_train.shape,y_train.shape)\n",
    "print(x_valid.shape,y_valid.shape)\n",
    "print(x_test.shape,y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对数据进行归一化\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "x_train_scaled = scaler.fit_transform(x_train)\n",
    "x_valid_scaled = scaler.transform(x_valid)\n",
    "x_test_scaled = scaler.transform(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "# 定义生成csv文件的文件夹\n",
    "output_dir = os.path.join('generate_csv')\n",
    "if not os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "else:\n",
    "    shutil.rmtree(output_dir)\n",
    "    os.mkdir(output_dir)\n",
    "    \n",
    "# 将x和y按行进行merge\n",
    "train_data = np.c_[x_train_scaled, y_train]\n",
    "valid_data = np.c_[x_valid_scaled, y_valid]\n",
    "test_data = np.c_[x_test_scaled, y_test]\n",
    "\n",
    "# 定义header信息\n",
    "header_cols = housing.feature_names + ['MidianHouseValue']\n",
    "header_str = ','.join(header_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将数据保存到csv文件\n",
    "def save_to_csv(output_dir, data, name_prefix, header=None, n_parts=10):\n",
    "    path_format = os.path.join(output_dir, '{}_{:02d}.csv') # 定义文件名格式\n",
    "    filenames = []\n",
    "    \n",
    "    # 将数据分成n_parts个部分，获得每组索引和值，遍历\n",
    "    for file_idx, row_indices in enumerate(\n",
    "        np.array_split(np.arange(len(data)), n_parts)):\n",
    "        part_csv = path_format.format(name_prefix, file_idx)\n",
    "        filenames.append(part_csv)\n",
    "        with open(part_csv, 'wt', encoding='utf-8') as f:\n",
    "            if header is not None:\n",
    "                f.write(header + '\\n')\n",
    "            # 对每一组进行遍历，将取到的col字符串化，再用逗号连接\n",
    "            for row_index in row_indices:\n",
    "                f.write(','.join([repr(col) for col in data[row_index]]))\n",
    "                f.write('\\n')\n",
    "    # 返回所有的文件\n",
    "    return filenames\n",
    "\n",
    "train_filenames = save_to_csv(output_dir,train_data,'train',header_str,n_parts=20)\n",
    "valid_filenames = save_to_csv(output_dir,valid_data,'valid',header_str,n_parts=10)\n",
    "test_filenames = save_to_csv(output_dir,test_data,'test',header_str,n_parts=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 解析csv文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train filenames:\n",
      "['generate_csv\\\\train_00.csv',\n",
      " 'generate_csv\\\\train_01.csv',\n",
      " 'generate_csv\\\\train_02.csv',\n",
      " 'generate_csv\\\\train_03.csv',\n",
      " 'generate_csv\\\\train_04.csv',\n",
      " 'generate_csv\\\\train_05.csv',\n",
      " 'generate_csv\\\\train_06.csv',\n",
      " 'generate_csv\\\\train_07.csv',\n",
      " 'generate_csv\\\\train_08.csv',\n",
      " 'generate_csv\\\\train_09.csv',\n",
      " 'generate_csv\\\\train_10.csv',\n",
      " 'generate_csv\\\\train_11.csv',\n",
      " 'generate_csv\\\\train_12.csv',\n",
      " 'generate_csv\\\\train_13.csv',\n",
      " 'generate_csv\\\\train_14.csv',\n",
      " 'generate_csv\\\\train_15.csv',\n",
      " 'generate_csv\\\\train_16.csv',\n",
      " 'generate_csv\\\\train_17.csv',\n",
      " 'generate_csv\\\\train_18.csv',\n",
      " 'generate_csv\\\\train_19.csv']\n",
      "valid filenames:\n",
      "['generate_csv\\\\valid_00.csv',\n",
      " 'generate_csv\\\\valid_01.csv',\n",
      " 'generate_csv\\\\valid_02.csv',\n",
      " 'generate_csv\\\\valid_03.csv',\n",
      " 'generate_csv\\\\valid_04.csv',\n",
      " 'generate_csv\\\\valid_05.csv',\n",
      " 'generate_csv\\\\valid_06.csv',\n",
      " 'generate_csv\\\\valid_07.csv',\n",
      " 'generate_csv\\\\valid_08.csv',\n",
      " 'generate_csv\\\\valid_09.csv']\n",
      "test filenames:\n",
      "['generate_csv\\\\test_00.csv',\n",
      " 'generate_csv\\\\test_01.csv',\n",
      " 'generate_csv\\\\test_02.csv',\n",
      " 'generate_csv\\\\test_03.csv',\n",
      " 'generate_csv\\\\test_04.csv',\n",
      " 'generate_csv\\\\test_05.csv',\n",
      " 'generate_csv\\\\test_06.csv',\n",
      " 'generate_csv\\\\test_07.csv',\n",
      " 'generate_csv\\\\test_08.csv',\n",
      " 'generate_csv\\\\test_09.csv']\n"
     ]
    }
   ],
   "source": [
    "# 打印文件名\n",
    "import pprint\n",
    "print('train filenames:')\n",
    "pprint.pprint(train_filenames)\n",
    "print('valid filenames:')\n",
    "pprint.pprint(valid_filenames)\n",
    "print('test filenames:')\n",
    "pprint.pprint(test_filenames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "解析csv步骤：\n",
    "1. filename -> dataset\n",
    "2. read file -> dataset -> datasets -> merge\n",
    "3. parse csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'generate_csv\\\\train_14.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_10.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_16.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_11.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_07.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_17.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_08.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_04.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_12.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_19.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_09.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_01.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_03.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_15.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_00.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_05.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_13.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_06.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_02.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv\\\\train_18.csv', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# 1.将文件名构成数据集\n",
    "# list_files专门处理文件名，会把文件名生成Dataset\n",
    "filename_dataset = tf.data.Dataset.list_files(train_filenames) # 将训练集所有文件构成dataset\n",
    "for filename in filename_dataset:\n",
    "    print(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b'0.6303435674178064,1.874166156711919,-0.06713214279531016,-0.12543366804152128,-0.19737553788322462,-0.022722631725889016,-0.692407235065288,0.7265233438487496,2.419'\n",
      "b'-1.0591781535672364,1.393564736946074,-0.026331968874673636,-0.11006759528831847,-0.6138198966579805,-0.09695934953589447,0.3247131133362288,-0.037477245413977976,0.672'\n",
      "b'-0.8219588176953616,1.874166156711919,0.18212349433218608,-0.03170019246279883,-0.6011178900722581,-0.14337494105109344,1.0852205298015787,-0.8613994495208361,1.054'\n",
      "b'0.6363646332204844,-1.0895425985107923,0.09260902815633619,-0.20538124656801682,1.2025670451003232,-0.03630122549633783,-0.6784101660505877,0.182235342347858,2.429'\n",
      "b'-1.1157655153587753,0.9930635538078697,-0.33419201318312125,-0.0653521844775239,-0.3289320346639209,0.04343065774347637,-0.12785878480573185,0.30707203993980686,0.524'\n",
      "b'1.6312258686346301,0.3522616607867429,0.04080576110152256,-0.1408895163348976,-0.4632103899987006,-0.06751623819156843,-0.8277122355407183,0.5966931783531273,3.376'\n",
      "b'-0.2223565745313433,1.393564736946074,0.02991299565857307,0.0801452044790158,-0.509481985418118,-0.06238599304952824,-0.86503775291325,0.8613469772480595,2.0'\n",
      "b'-0.46794146200516895,-0.9293421252555106,0.11909925912590703,-0.060470113038678074,0.30344643606811583,-0.021851890609536125,1.873722084296329,-1.0411642940532422,1.012'\n",
      "b'-1.0635474225567902,1.874166156711919,-0.49344892844525906,-0.06962612737313081,-0.273587577397559,-0.13419514417565354,1.0338979434143465,-1.3457658361775973,1.982'\n",
      "b'-0.9760554752293826,1.2333642636907922,-0.3909986321352606,-0.15728481711770903,-0.8261248638764835,-0.14088780945051624,1.360496220424008,-0.9512818717870428,1.136'\n",
      "b'0.04049225382803661,-0.6890414153725881,-0.44379851741607473,0.022374585146687852,-0.22187226486997497,-0.1482850314959248,-0.8883662012710817,0.6366409215825501,2.852'\n",
      "b'0.199384450496934,1.0731637904355105,-0.19840853933562783,-0.29328906965393414,-0.07852104768825069,0.018804888420646343,0.8006134598360177,-1.1510205879341566,1.99'\n",
      "b'-0.060214068004363165,0.7527628439249472,0.0835940301935345,-0.06250122441959183,-0.03497131082291674,-0.026442380178345683,1.0712234607868782,-1.3707331756959855,1.651'\n",
      "b'-0.9974222662636643,1.2333642636907922,-0.7577192870888144,-0.011109251557751528,-0.23003784053222506,0.05487422342718872,-0.757726890467217,0.7065494722340417,1.739'\n",
      "b'0.6289049056773436,-0.44874070548966555,0.011390452394941722,-0.21388453904713714,0.13196934716086342,-0.08002252121823207,-0.883700511599516,0.8813208488627673,2.522'\n"
     ]
    }
   ],
   "source": [
    "# 2.读取文件里的内容，构成新的dataset，再合并\n",
    "n_readers = 5\n",
    "dataset = filename_dataset.interleave(\n",
    "    lambda filename: tf.data.TextLineDataset(filename).skip(1),# 按行读取文件内容，跳过第一行header\n",
    "    cycle_length=n_readers\n",
    ")\n",
    "# 读取15条数据\n",
    "for line in dataset.take(15):\n",
    "    print(line.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<tf.Tensor: id=203, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=204, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=205, shape=(), dtype=float32, numpy=3.0>, <tf.Tensor: id=206, shape=(), dtype=string, numpy=b'4'>, <tf.Tensor: id=207, shape=(), dtype=float32, numpy=5.0>]\n"
     ]
    }
   ],
   "source": [
    "# 3.解析csv\n",
    "# tf.io.decode_csv(str,record_defaults)\n",
    "\n",
    "sample_str = '1,2,3,4,5'\n",
    "# 设置记录的默认值和数据类型\n",
    "record_defaults = [\n",
    "    tf.constant(0, dtype=tf.int32),\n",
    "    0,\n",
    "    np.nan,\n",
    "    'hello',\n",
    "    tf.constant([])\n",
    "]\n",
    "parsed_fields = tf.io.decode_csv(sample_str, record_defaults)\n",
    "print(parsed_fields)"
   ]
  },
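  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An added check (a sketch, not an original cell): empty fields fall back to their `record_defaults` entry; only a field whose default is an empty tensor (field 4 here) is strictly required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: fields 1 and 3 are empty and fall back to their defaults (0 and 'hello');\n",
    "# field 4 must be present because its default, tf.constant([]), marks it as required.\n",
    "parsed_with_defaults = tf.io.decode_csv('1,,3.0,,5', record_defaults)\n",
    "print(parsed_with_defaults)"
   ]
  },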
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Field 4 is required but missing in record 0! [Op:DecodeCSV]\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    parsed_fields = tf.io.decode_csv(',,,,', record_defaults)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expect 5 fields but have 7 in record 0 [Op:DecodeCSV]\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    parsed_fields = tf.io.decode_csv('1,2,3,4,5,6,7', record_defaults)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 读取csv文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: id=296, shape=(8,), dtype=float32, numpy=\n",
       " array([ 0.6289049 , -0.4487407 ,  0.01139045, -0.21388453,  0.13196935,\n",
       "        -0.08002252, -0.8837005 ,  0.88132083], dtype=float32)>,\n",
       " <tf.Tensor: id=297, shape=(1,), dtype=float32, numpy=array([2.522], dtype=float32)>)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将csv中的一行转化\n",
    "def parse_csv_line(line, n_fields=9):\n",
    "    defs = [tf.constant(np.nan)] * n_fields\n",
    "    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n",
    "    x = tf.stack(parsed_fields[0:-1]) # 将前n-1条数据转化为x向量\n",
    "    y = tf.stack(parsed_fields[-1:]) # 将最后一条转化为y向量\n",
    "    return x, y\n",
    "\n",
    "parse_csv_line(b'0.6289049056773436,-0.44874070548966555,0.011390452394941722,-0.21388453904713714,0.13196934716086342,-0.08002252121823207,-0.883700511599516,0.8813208488627673,2.522',\n",
    "               n_fields=9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完整流程实现csv读取、解析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:\n",
      "<tf.Tensor: id=381, shape=(3, 8), dtype=float32, numpy=\n",
      "array([[ 0.63636464, -1.0895426 ,  0.09260903, -0.20538124,  1.2025671 ,\n",
      "        -0.03630123, -0.6784102 ,  0.18223535],\n",
      "       [ 0.63034356,  1.8741661 , -0.06713215, -0.12543367, -0.19737554,\n",
      "        -0.02272263, -0.69240725,  0.72652334],\n",
      "       [-1.0591781 ,  1.3935647 , -0.02633197, -0.1100676 , -0.6138199 ,\n",
      "        -0.09695935,  0.3247131 , -0.03747724]], dtype=float32)>\n",
      "y:\n",
      "<tf.Tensor: id=382, shape=(3, 1), dtype=float32, numpy=\n",
      "array([[2.429],\n",
      "       [2.419],\n",
      "       [0.672]], dtype=float32)>\n",
      "x:\n",
      "<tf.Tensor: id=383, shape=(3, 8), dtype=float32, numpy=\n",
      "array([[ 0.48530516, -0.8492419 , -0.06530126, -0.02337966,  1.4974351 ,\n",
      "        -0.07790658, -0.90236324,  0.78145146],\n",
      "       [ 0.40127665, -0.92934215, -0.0533305 , -0.18659453,  0.65456617,\n",
      "         0.02643447,  0.9312528 , -1.4406418 ],\n",
      "       [-1.0635474 ,  1.8741661 , -0.49344894, -0.06962613, -0.27358758,\n",
      "        -0.13419515,  1.033898  , -1.3457658 ]], dtype=float32)>\n",
      "y:\n",
      "<tf.Tensor: id=384, shape=(3, 1), dtype=float32, numpy=\n",
      "array([[2.956],\n",
      "       [2.512],\n",
      "       [1.982]], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "def csv_reader_dateset(filenames, n_readers=5, batch_size=32,\n",
    "                       n_parse_threads=5, shuffle_buffer_size=10000):\n",
    "    # 1.filename -> filename dataset\n",
    "    dataset = tf.data.Dataset.list_files(filenames)\n",
    "    dataset = dataset.repeat() # 将数据重复无限次\n",
    "    # 2.filename dataset -> text dataset 一对多的关系\n",
    "    dataset = dataset.interleave(\n",
    "        lambda filename : tf.data.TextLineDataset(filename).skip(1),\n",
    "        cycle_length=n_readers\n",
    "    )\n",
    "    dataset.shuffle(shuffle_buffer_size) # 将数据进行混排\n",
    "    # 3.parse csv 一对一关系\n",
    "    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset\n",
    "\n",
    "# test\n",
    "train_set = csv_reader_dateset(train_filenames, batch_size=3)\n",
    "for x_batch, y_batch in train_set.take(2):\n",
    "    print('x:')\n",
    "    pprint.pprint(x_batch)\n",
    "    print('y:')\n",
    "    pprint.pprint(y_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "train_set = csv_reader_dateset(train_filenames, batch_size=batch_size)\n",
    "valid_set = csv_reader_dateset(valid_filenames, batch_size=batch_size)\n",
    "test_set = csv_reader_dateset(test_filenames, batch_size=batch_size)"
   ]
  },
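  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common follow-up (an added sketch, assuming TF 2.0's `prefetch` and `tf.data.experimental.AUTOTUNE`): appending `prefetch` lets the input pipeline prepare upcoming batches while the model is busy training on the current one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: overlap data preparation with training by prefetching batches.\n",
    "train_set_prefetched = csv_reader_dataset(\n",
    "    train_filenames, batch_size=batch_size\n",
    ").prefetch(tf.data.experimental.AUTOTUNE)"
   ]
  },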
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 与keras集成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense (Dense)                (None, 30)                270       \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 31        \n",
      "=================================================================\n",
      "Total params: 301\n",
      "Trainable params: 301\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Train for 348 steps, validate for 120 steps\n",
      "Epoch 1/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 2.5439 - val_loss: 0.7588\n",
      "Epoch 2/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.6221 - val_loss: 0.5885\n",
      "Epoch 3/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.5125 - val_loss: 0.5060\n",
      "Epoch 4/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4387 - val_loss: 0.4558\n",
      "Epoch 5/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4204 - val_loss: 0.4272\n",
      "Epoch 6/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3929 - val_loss: 0.4170\n",
      "Epoch 7/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4001 - val_loss: 0.4136\n",
      "Epoch 8/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3814 - val_loss: 0.4015\n",
      "Epoch 9/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3867 - val_loss: 0.3966\n",
      "Epoch 10/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3673 - val_loss: 0.3942\n",
      "Epoch 11/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3720 - val_loss: 0.3929\n",
      "Epoch 12/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3738 - val_loss: 0.3842\n",
      "Epoch 13/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3624 - val_loss: 0.3747\n",
      "Epoch 14/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3619 - val_loss: 0.3777\n",
      "Epoch 15/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3631 - val_loss: 0.3721\n",
      "Epoch 16/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3498 - val_loss: 0.3651\n",
      "Epoch 17/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3548 - val_loss: 0.3783\n",
      "Epoch 18/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3601 - val_loss: 0.3639\n",
      "Epoch 19/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3451 - val_loss: 0.3623\n",
      "Epoch 20/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3449 - val_loss: 0.3586\n",
      "Epoch 21/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3454 - val_loss: 0.3582\n",
      "Epoch 22/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3486 - val_loss: 0.3537\n",
      "Epoch 23/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3408 - val_loss: 0.3514\n",
      "Epoch 24/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3362 - val_loss: 0.3512\n",
      "Epoch 25/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3349 - val_loss: 0.3490\n"
     ]
    }
   ],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation='relu',input_shape=[8]),\n",
    "    keras.layers.Dense(1)\n",
    "])\n",
    "model.summary()\n",
    "\n",
    "model.compile(loss='mean_squared_error',optimizer='adam')\n",
    "\n",
    "callbacks = [keras.callbacks.EarlyStopping(patience=5,min_delta=1e-2)]\n",
    "history = model.fit(train_set,\n",
    "                    validation_data = valid_set,\n",
    "                    steps_per_epoch = 11160 // batch_size,\n",
    "                    validation_steps = 3870 // batch_size,\n",
    "                    epochs = 100,\n",
    "                    callbacks = callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "161/161 [==============================] - 1s 3ms/step - loss: 0.3524\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.3523958790931642"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_set, steps = 5160 // batch_size)"
   ]
  },
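  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An added sketch: plot the learning curves from the `history` object returned by `fit` above, using the pandas and matplotlib already imported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: visualize training vs. validation loss per epoch.\n",
    "pd.DataFrame(history.history).plot(figsize=(8, 5))\n",
    "plt.grid(True)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.show()"
   ]
  },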
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
